import argparse
import time
import torch

import model_transformer as model_management
import utils

# %%
def _str2bool(value):
    """Parse a command-line boolean such as ``True``/``False``/``1``/``0``.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is ``True`` for
    every non-empty string (including ``"False"``), so flags that take an
    explicit value need a real parser.

    Raises:
        argparse.ArgumentTypeError: if the value is not a recognised boolean.
    """
    if isinstance(value, bool):
        # Non-string defaults are passed through untouched.
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


# Command-line options for pre-training (or adapting) the SSL transformer.
parser = argparse.ArgumentParser()
parser.add_argument('--datasets', type=str, default='citation',
                    help='specify multiple datasets in DOUBLE QUOTES separated by commas, e.g., \"cora, dblp\"')
parser.add_argument('--epochs', '-e', type=int, default=2500,
                    help='number of epochs for pre-training')
parser.add_argument('--encoder_dim', type=int, default=256,
                    help='dimension of transformer embeddings')
parser.add_argument('--save', type=_str2bool, default=True,
                    help='whether to save model after pre-training')
parser.add_argument('--adapt', type=_str2bool, default=False,
                    help='whether to train from scratch, or adapt existing model to new dataset (REQUIRES PATH OF'
                         ' EXISTING MODEL)')
parser.add_argument('--model_path', type=str, default='', help='path of model (only when adapt is True)')
args = parser.parse_args()
# Fail fast with a usage message instead of a later torch.load('') error.
if args.adapt and not args.model_path:
    parser.error('--model_path is required when --adapt is True')

# %%
# NOTE(review): start_time is not read anywhere in this script; confirm whether
# elapsed-time reporting was intended.
start_time = time.time()

# The help text suggests "cora, dblp", so strip whitespace around each name and
# drop empty entries (e.g. from a trailing comma) before the dataset lookup.
dataset_names = [name.strip() for name in args.datasets.split(',') if name.strip()]
if not dataset_names:
    raise ValueError('--datasets must name at least one dataset')

# Example dataset names, grouped by family (support depends on
# utils.create_data_structure):
#   citation full: 'cora', 'dblp', 'cora_ml', 'citeseer', 'pubmed'
#   planetoid:     'cora-planetoid', 'pubmed-planetoid', 'citeseer-planetoid'
#   amazon:        'amazon-computers', 'amazon-photo'
#   other:         'ppi', 'ogbn-arxiv', 'ogbl-collab'
data = utils.create_data_structure(
    datasets=dataset_names,
    ssl=True,
    model_type='transformer',
)
num_datasets = len(data.datasets)

# %%
if args.adapt:
    # Adapt mode: reuse the stem and backbone configuration stored in a saved
    # checkpoint and append one new input stem for the new dataset.
    # map_location='cpu' lets a checkpoint written during GPU training load on a
    # CPU-only host; the weights are later copied into the encoder's own
    # parameters by load_state_dict, so the device they are loaded on is irrelevant.
    pretrained_model = torch.load(args.model_path, map_location='cpu')
    stems = pretrained_model['configs']['stems']
    if stems[0]['num_layer_features'] != args.encoder_dim:
        raise ValueError("The dimension of transformer encoder should be equal to that of saved pre-trained model.")
    # NOTE(review): only data.datasets[0] gets a new stem, so adapt mode appears
    # to assume a single target dataset — confirm.
    # NOTE(review): this branch reads `num_node_features` while the scratch branch
    # reads `num_features`; presumably equivalent (as in PyG) — confirm.
    stems.append(
        {
            'num_node_f': data.datasets[0].num_node_features,  # input features
            'num_edge_f': [],
            'num_layer_features': args.encoder_dim,  # transformer input features
            'layer_type': 'lin',
            'act': 'relu',
        }
    )
    backbone = pretrained_model['configs']['backbone']
else:
    # Train from scratch: one 'lin' stem per dataset, mapping that dataset's
    # input features to args.encoder_dim.
    stems = [
        {
            'num_node_f': dataset.num_features,  # input features
            'num_edge_f': [],
            'num_layer_features': args.encoder_dim,  # transformer input features
            'layer_type': 'lin',
            'act': 'relu',
        }
        for dataset in data.datasets
    ]

    # Shared transformer backbone; its input width equals the stems' output width.
    backbone = {
        'num_in_features': stems[0]['num_layer_features'],
        'num_layer_features': args.encoder_dim,  # transformer hidden features
        'num_heads': 8,
        'num_layers': 4,
        'act': 'relu',
    }

# Self-supervised task(s) passed to SSLTransformer.
ssl_tasks = ['pairdis']

# %%
# Instantiate the encoder from the stem/backbone configuration assembled above.
encoder = model_management.TransformerEncoder(stems=stems, backbone=backbone, num_hops=data.num_hops)

if args.adapt:
    # Start from the pre-trained weights. strict=False tolerates non-matching
    # keys (presumably the freshly appended stem has no saved weights).
    encoder.load_state_dict(pretrained_model['model_state_dict'], strict=False)
    # Freeze the backbone and every stem, then re-enable gradients only on the
    # last stem (the one appended for the new dataset).
    encoder.backbone.requires_grad_(False)
    encoder.stems.requires_grad_(False)
    encoder.stems[-1].requires_grad_(True)

ssl_transformer = model_management.SSLTransformer(encoder=encoder, data=data, ssl_tasks=ssl_tasks)
del encoder  # only ssl_transformer is used from here on

# Save only when requested; pretrain() receives None otherwise.
save_path = None
if args.save:
    model_name = f'ssl_transformer_{num_datasets}_datasets_neighbourhood_aggregation.pt'
    save_path = utils.create_save_path(model_name)
ssl_transformer.pretrain(num_epochs=args.epochs, save_path=save_path)
